'''
Module for handling FAISS database operations for question similarity search.
Includes initialization, searching, and utility functions; state is kept in
module-level globals populated by initialize_database().
Uses SentenceTransformers for embeddings and FAISS for vector search.
'''

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from typing import List, Tuple
import re

# Module-level state shared by the functions below. Everything starts empty
# and is populated by initialize_database().
model = None  # SentenceTransformer used to embed stored questions and queries
index = None  # FAISS index with one embedding per entry of questions_db
questions_db: List[Tuple] = []  # (id, question) tuples, in the same order as the index rows

def initialize_database(questions: List[Tuple]) -> bool:
    '''
    Build the FAISS index from (id, question_text) tuples, e.g. rows fetched
    from AdvisorQuestions in the database.

    The SentenceTransformer model is loaded on first use and reused on later
    calls, so re-initializing with a new question list does not reload it.

    Module state (model, index, questions_db) is replaced only after the new
    index has been built successfully. On failure the previous state is left
    untouched, so questions_db and index never get out of sync.

    Args:
        questions: Sequence of tuples where element [0] is the question id and
            element [1] is the question text.

    Returns:
        True if initialization succeeded, False otherwise (including when
        ``questions`` is empty).
    '''
    global model, index, questions_db
    if not questions:
        print("Failed to initialize database: no questions provided")
        return False
    try:
        # Loading the model is the expensive step; reuse it if already loaded.
        encoder = model
        if encoder is None:
            encoder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
        new_questions = list(questions)
        texts = [q[1] for q in new_questions]
        # FAISS requires a C-contiguous float32 matrix.
        embeddings = np.ascontiguousarray(encoder.encode(texts), dtype=np.float32)
        dimension = embeddings.shape[1]
        new_index = faiss.IndexFlatL2(dimension)  # exact search on (squared) L2 distance
        new_index.add(embeddings)
    except Exception as e:
        print(f"Failed to initialize database: {e}")
        return False
    # Commit all state together, only after every step above has succeeded.
    model = encoder
    index = new_index
    questions_db = new_questions
    return True

def calculate_similarity_threshold(distances: np.ndarray,
                                   std_factor: float = 0.5,
                                   best_multiplier: float = 2.5) -> float:
    '''
    Calculate a dynamic distance cut-off from the distribution of candidate distances.

    Two thresholds are computed and the more restrictive (smaller) one is used:
      * ``mean - std_factor * std`` of all distances, and
      * ``best_multiplier`` times the best (smallest) distance.
    The result is never allowed to fall below the best distance itself, so
    the closest candidate always passes a ``distance > threshold`` check.

    Args:
        distances: 1-D array or sequence of distances; smaller means more similar.
        std_factor: Number of standard deviations below the mean to cut off at.
        best_multiplier: Multiple of the best distance that is still accepted.

    Returns:
        The threshold as a Python float, or ``inf`` when ``distances`` is empty.
    '''
    values = np.asarray(distances, dtype=np.float64)
    if values.size == 0:
        return float('inf')
    best = float(values.min())
    statistical_cutoff = float(values.mean() - std_factor * values.std())
    threshold = min(statistical_cutoff, best * best_multiplier)
    # With many near-identical distances plus one outlier, mean - k*std can drop
    # below the best distance and would otherwise reject every candidate.
    return max(threshold, best)

# Common English function words that carry little meaning for matching.
# Built once at import time rather than on every extract_keywords() call.
# (Entries shorter than 3 letters can never match _KEYWORD_PATTERN; they are
# kept so the list stays valid if the minimum word length is lowered.)
_STOP_WORDS = frozenset({
    'i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', 'your', 'yours',
    'yourself', 'yourselves', 'he', 'him', 'his', 'himself', 'she', 'her', 'hers',
    'herself', 'it', 'its', 'itself', 'they', 'them', 'their', 'theirs', 'themselves',
    'what', 'which', 'who', 'whom', 'this', 'that', 'these', 'those', 'am', 'is', 'are',
    'was', 'were', 'be', 'been', 'being', 'have', 'has', 'had', 'having', 'do', 'does',
    'did', 'doing', 'a', 'an', 'the', 'and', 'but', 'if', 'or', 'because', 'as', 'until',
    'while', 'of', 'at', 'by', 'for', 'with', 'about', 'against', 'between', 'into',
    'through', 'during', 'before', 'after', 'above', 'below', 'up', 'down', 'in', 'out',
    'on', 'off', 'over', 'under', 'again', 'further', 'then', 'once', 'can', 'could',
    'should', 'would', 'will',
})

# Whole words made of 3+ ASCII letters. Tokens containing digits (e.g. "cs101")
# or shorter than 3 letters are not matched.
_KEYWORD_PATTERN = re.compile(r'\b[a-zA-Z]{3,}\b')

def extract_keywords(text: str) -> set:
    '''
    Extract the set of lowercase keywords from ``text`` for overlap matching.

    The text is lowercased, split into whole words of 3+ ASCII letters, and
    common stop words are discarded.
    '''
    words = _KEYWORD_PATTERN.findall(text.lower())
    return {word for word in words if word not in _STOP_WORDS}

def calculate_keyword_overlap(query: str, candidate: str) -> float:
    '''
    Return the fraction of the query's keywords that also occur in the candidate.

    The score is asymmetric: it is divided by the number of query keywords
    only, so a candidate containing every query keyword scores 1.0 no matter
    how many extra keywords it has. A query with no keywords (only stop words
    or very short words) scores 0.0.
    '''
    wanted = extract_keywords(query)
    available = extract_keywords(candidate)
    if len(wanted) == 0:
        # Nothing meaningful to match on, so report no similarity at all.
        return 0.0
    shared = wanted & available
    return len(shared) / len(wanted)

def search_query(query: str, k: int = 5, min_similarity: float = 0.15, max_results: int = 3) -> List[int]:
    '''
    Find the stored questions most relevant to ``query``.

    Retrieval happens in two stages. First, FAISS returns the ``2 * k``
    nearest questions by embedding distance (capped at the database size).
    Next, each candidate must pass a dynamic distance threshold and a minimum
    keyword overlap. The survivors are ranked by a combined score: 60%
    normalized embedding closeness plus 40% keyword overlap.

    Args:
        query: The search query string.
        k: Base candidate count; ``2 * k`` candidates are fetched from FAISS
            so there is room left over after filtering.
        min_similarity: Minimum keyword overlap (0.0-1.0) a candidate needs.
        max_results: Maximum number of question IDs to return.

    Returns:
        IDs of the most relevant questions, best first. An empty list is
        returned when nothing qualifies, when ``k`` or ``max_results`` is not
        positive, or on any error (the error is printed).
    '''
    global model, index, questions_db
    try:
        if model is None or index is None:
            raise Exception("Database not initialized. Call initialize_database() first.")
        if not questions_db:
            raise Exception("No questions in database.")
        # Search more candidates initially to have options for filtering
        search_k = min(k * 2, len(questions_db))
        if search_k <= 0 or max_results <= 0:
            # Nothing can be returned; a negative max_results would otherwise
            # slice from the end and return almost every candidate.
            return []
        query_vector = model.encode([query]).astype(np.float32)
        distances, indices = index.search(query_vector, search_k)
        distances = distances[0]
        indices = indices[0]
        # Calculate dynamic threshold based on distance distribution
        similarity_threshold = calculate_similarity_threshold(distances)
        # Largest retrieved distance, used to normalize distances into [0, 1];
        # computed once rather than for every candidate.
        max_distance = float(np.max(distances))
        candidates = []
        for rank, (distance, idx) in enumerate(zip(distances, indices)):
            if idx < 0:
                continue  # FAISS marks unfilled result slots with -1; never index questions_db with it
            if distance > similarity_threshold:
                continue  # Skip if too dissimilar
            question_id, question_text = questions_db[idx]
            keyword_overlap = calculate_keyword_overlap(query, question_text)
            # Skip if no meaningful keyword overlap
            if keyword_overlap < min_similarity:
                continue
            # Invert the normalized distance so that higher means closer.
            if max_distance > 0:
                normalized_distance_score = max(0.0, 1.0 - float(distance) / max_distance)
            else:
                # Every retrieved distance is 0 (exact matches): score them as
                # perfectly close instead of computing 0/0 = NaN.
                normalized_distance_score = 1.0
            combined_score = 0.6 * normalized_distance_score + 0.4 * keyword_overlap
            candidates.append({
                'id': question_id,
                'question': question_text,
                'distance': distance,
                'keyword_overlap': keyword_overlap,
                'combined_score': combined_score,
                'original_rank': rank
            })
        # Sort by combined score (stable, so FAISS order breaks ties)
        candidates.sort(key=lambda c: c['combined_score'], reverse=True)
        top_candidates = candidates[:max_results]
        result_ids = [c['id'] for c in top_candidates]
        # Debug information
        if candidates:
            print(f"Query: '{query}'")
            print(f"Found {len(candidates)} relevant candidates, returning top {len(result_ids)}")
            for i, c in enumerate(top_candidates):
                print(f"  {i+1}. ID {c['id']}: score={c['combined_score']:.3f}, "
                      f"keywords={c['keyword_overlap']:.2f}, distance={c['distance']:.3f}")
                print(f"     Question: {c['question'][:100]}...")
        else:
            print(f"No relevant questions found for query: '{query}'")
        return result_ids
    except Exception as e:
        print(f"Error during search: {e}")
        return []

def get_database_stats() -> dict:
    '''
    Summarize the module's current state.

    The returned dict reports the number of stored questions, whether the
    model and the FAISS index have been set up, and, when an index exists,
    its vector dimension and stored vector count (both 0 otherwise).
    '''
    # Read-only access to module state, so no `global` declaration is needed.
    stats = {}
    stats["total_questions"] = len(questions_db)
    stats["model_loaded"] = model is not None
    has_index = index is not None
    stats["index_initialized"] = has_index
    stats["dimension"] = index.d if has_index else 0
    stats["index_size"] = index.ntotal if has_index else 0
    return stats

def is_initialized() -> bool:
    '''
    Report whether searches can be served.

    True only when the model is loaded, the FAISS index exists, and at least
    one question is stored.
    '''
    # Read-only access to module state, so no `global` declaration is needed.
    if model is None or index is None:
        return False
    return len(questions_db) != 0